import torch
import torchvision
import torch.nn as nn

# Select the compute device once at import time: the first CUDA device when
# one is available, otherwise the CPU. Report the choice on stdout.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on GPU" if device.type == "cuda" else "Working on CPU")

def CreateResNet(num_classes=10,
                 frozen_prefixes=("conv1.", "bn1.", "layer1.")):
    """Build a torchvision ResNet-18 with a new classification head.

    The backbone is created via ``torchvision.models.resnet18(pretrained=True)``.
    Its final fully connected layer ``fc`` is replaced by a freshly
    initialised ``Linear(in_features, num_classes)``. Every parameter whose
    name starts with one of ``frozen_prefixes`` gets ``requires_grad = False``.
    All other parameters stay trainable: by default that is ``layer2``-``layer4``
    and the new ``fc`` head. The finished model is moved to the module-level
    ``device``.

    Args:
        num_classes: Output size of the new ``fc`` head. The default 10
            matches the previously hard-coded value.
        frozen_prefixes: Iterable of parameter-name prefixes to freeze.
            The default freezes the stem (``conv1``, ``bn1``) and ``layer1``.

    Returns:
        The modified ResNet-18 module, placed on ``device``.
    """
    resnet = torchvision.models.resnet18(pretrained=True)
    # Take the head's input width from the model instead of hard-coding 512.
    resnet.fc = torch.nn.Linear(in_features=resnet.fc.in_features,
                                out_features=num_classes, bias=True)
    # Match whole name prefixes rather than single characters. The old test
    # ``name[2] < '2'`` also matched "fc.weight"/"fc.bias" (since '.' < '2'),
    # which froze the freshly initialised classifier head.
    prefixes = tuple(frozen_prefixes)  # str.startswith requires a tuple
    for name, param in resnet.named_parameters():
        if name.startswith(prefixes):
            param.requires_grad = False
    # NOTE: for multi-GPU training, wrap the returned model in nn.DataParallel.
    return resnet.to(device)